{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pynvml import *\n",
    "\n",
    "# Abort early unless GPU index 1 has enough free memory for the run.\n",
    "nvmlInit()\n",
    "try:\n",
    "    # Dividing by 1024**2 turns NVML's byte count into the MB figure printed below.\n",
    "    vram = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(1)).free/1024.**2\n",
    "    print('GPU1 Memory: %dMB' % vram)\n",
    "    if vram < 8000:\n",
    "        raise Exception('GPU Memory too low')\n",
    "finally:\n",
    "    # Release NVML on every path, including when the low-memory check raises.\n",
    "    nvmlShutdown()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import cv2\n",
    "import h5py\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import *\n",
    "from collections import Counter\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "import re\n",
    "import time\n",
    "import random\n",
    "\n",
    "from keras.models import *\n",
    "import keras.backend as K\n",
    "from make_parallel import make_parallel\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Alphabet used by decode() to turn class indices back into characters.\n",
    "characters = u'0123456789()+-*/=君不见黄河之水天上来奔流到海复回烟锁池塘柳深圳铁板烧; '\n",
    "\n",
    "# Decoded sequences are truncated to this many symbols.\n",
    "n_len = 51\n",
    "# Value handed to the CTC decoder as the per-sample input length.\n",
    "rnn_length = 110\n",
    "\n",
    "# Number of test images, image geometry, and alphabet size.\n",
    "n = 100000\n",
    "width = 900\n",
    "height = 81\n",
    "n_class = len(characters)\n",
    "channels = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def decode(out):\n",
    "    '''Map a sequence of class indices to the text it represents.\n",
    "\n",
    "    Only indices strictly between -1 and n_class-1 are translated through\n",
    "    `characters`; negative values and the last class index are skipped.\n",
    "    '''\n",
    "    kept = [characters[idx] for idx in out if -1 < idx < n_class - 1]\n",
    "    return ''.join(kept)\n",
    "\n",
    "def disp3(index):\n",
    "    '''Plot test image number `index` with its decoded prediction as the title.\n",
    "\n",
    "    Reads the notebook-level arrays `X` (images) and `out` (decoded indices).\n",
    "    '''\n",
    "    predicted = decode(out[index])\n",
    "    plt.figure(figsize=(16, 4))\n",
    "    # Swap the first two axes of the stored image before handing it to imshow.\n",
    "    picture = X[index].transpose(1, 0, 2)\n",
    "    plt.imshow(picture)\n",
    "    plt.title('pred:%s' % predicted)\n",
    "\n",
    "def disp2(img):\n",
    "    '''Write `img` to the file a.png with OpenCV and return an IPython Image of that file.'''\n",
    "    path = 'a.png'\n",
    "    cv2.imwrite(path, img)\n",
    "    return Image(path)\n",
    "\n",
    "def disp(img, txt=None, first=False):\n",
    "    '''Draw `img` in the next row of a 4x1 subplot grid.\n",
    "\n",
    "    img   -- 2-D input is shown with a gray colormap; otherwise the last axis is\n",
    "             reversed before display (presumably BGR -> RGB for OpenCV images).\n",
    "    txt   -- optional subplot title; skipped when falsy.\n",
    "    first -- when true, open a new 16x9 figure and restart at row 1; otherwise\n",
    "             advance the global `index` counter by one.\n",
    "    '''\n",
    "    global index\n",
    "    if not first:\n",
    "        index += 1\n",
    "    else:\n",
    "        index = 1\n",
    "        plt.figure(figsize=(16, 9))\n",
    "    plt.subplot(4, 1, index)\n",
    "    grayscale = len(img.shape) == 2\n",
    "    if grayscale:\n",
    "        plt.imshow(img, cmap='gray')\n",
    "    else:\n",
    "        plt.imshow(img[:,:,::-1])\n",
    "    if txt:\n",
    "        plt.title(txt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读取测试集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preallocate the whole test set; regions not covered by an image stay zero.\n",
    "X = np.zeros((n, width, height, channels), dtype=np.uint8)\n",
    "\n",
    "for i in tqdm(range(n)):\n",
    "    path = 'crop_split2_test/%d.png' % i\n",
    "    img = cv2.imread(path)\n",
    "    # cv2.imread reports a missing or unreadable file by returning None, not by raising.\n",
    "    if img is None:\n",
    "        raise IOError('Cannot read test image: %s' % path)\n",
    "    # Swap the first two axes so the image layout matches X's (width, height, channels).\n",
    "    img = img.transpose(1, 0, 2)\n",
    "    a, b, _ = img.shape\n",
    "    X[i, :a, :b] = img"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint tag; it selects the model file and is reused in the result file name.\n",
    "z = '150'\n",
    "\n",
    "base_model = load_model('model_346_split2_4_%s.h5' % z)\n",
    "\n",
    "# Wrap the model with make_parallel (second argument 4) and predict on the full test set.\n",
    "base_model2 = make_parallel(base_model, 4)\n",
    "y_pred = base_model2.predict(X, batch_size=500, verbose=1)\n",
    "\n",
    "# CTC decoding: drop the first two time steps, give every sample rnn_length as its\n",
    "# input length, take the first decoded result and keep at most n_len symbols.\n",
    "ctc_lengths = np.ones(y_pred.shape[0])*rnn_length\n",
    "ctc_paths = K.ctc_decode(y_pred[:,2:], input_length=ctc_lengths)[0]\n",
    "out = K.get_value(ctc_paths[0])[:, :n_len]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Decode every prediction, evaluate the expression it contains, and write\n",
    "# one 'text value' line per sample to the result file.\n",
    "# NOTE: the file write below relies on Python 2 string semantics (python2 kernel).\n",
    "ss = list(map(decode, out))\n",
    "n_samples = len(ss)\n",
    "\n",
    "vals = []\n",
    "errs = []\n",
    "errsid = []\n",
    "for i in tqdm(range(n_samples)):\n",
    "    val = ''\n",
    "    try:\n",
    "        # Layout handled here: 'name=value;...;expression' -- the last ';' field is the\n",
    "        # expression, earlier fields are substitutions applied to it.\n",
    "        a = ss[i].split(';')\n",
    "        s = a[-1]\n",
    "        for x in a[:-1]:\n",
    "            x, c = x.split('=')\n",
    "            # Appending '.0' makes the substituted value a float literal.\n",
    "            s = s.replace(x, c+'.0')\n",
    "        # SECURITY NOTE: eval() executes text produced by the model. The alphabet in\n",
    "        # `characters` limits it to digits, operators and CJK symbols, but treat the\n",
    "        # input as untrusted if the alphabet or data source ever changes.\n",
    "        val = '%.2f' % eval(s)\n",
    "    except Exception:\n",
    "        # Any malformed prediction is recorded as an error and blanked in the output.\n",
    "        # Catching Exception (not a bare except) lets KeyboardInterrupt stop the loop.\n",
    "        # To inspect a failing sample, call disp3(i) here.\n",
    "        errs.append(ss[i])\n",
    "        errsid.append(i)\n",
    "        ss[i] = ''\n",
    "\n",
    "    vals.append(val)\n",
    "\n",
    "# The file name carries the checkpoint tag and the number of failed samples.\n",
    "with open('result_%s_%d.txt' % (z, len(errs)), 'w') as f:\n",
    "    f.write('\\n'.join(map(' '.join, list(zip(ss, vals)))).encode('utf-8'))\n",
    "\n",
    "# Failure count, then the fraction of samples that evaluated successfully.\n",
    "print(len(errs))\n",
    "print(1-len(errs)/float(n_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
